{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46a8cc1a-995f-4efa-8924-aa7cdd8812f1",
   "metadata": {},
   "source": [
    "# Natural SQL 7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5beee445-bd5e-4435-b7c9-e2774e9e3f0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install transformers==4.35.2 accelerate sqlparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3206ac23-c2f1-40d9-9cf5-9aa7d9bea0f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"chatdb/natural-sql-7b\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"chatdb/natural-sql-7b\",\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.float16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ae0310e4-f358-4159-818d-fdb15d78f551",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Show me the day with the most users joining\n",
      "SQL: \n",
      "\n",
      "     SELECT created_at::date AS join_date, COUNT(*) AS user_count\n",
      "     FROM users\n",
      "     GROUP BY join_date\n",
      "     ORDER BY user_count DESC\n",
      "     LIMIT 1;\n",
      "Question: Show me the project that has a task with the most comments\n",
      "SQL: \n",
      "\n",
      "     SELECT p.project_id, p.project_name, COUNT(c.comment_id) AS comment_count\n",
      "     FROM projects p\n",
      "     JOIN tasks t ON p.project_id = t.project_id\n",
      "     JOIN comments c ON t.task_id = c.task_id\n",
      "     GROUP BY p.project_id\n",
      "     ORDER BY comment_count DESC\n",
      "     LIMIT 1;\n",
      "Question: What is the ratio of users with gmail addresses vs without?\n",
      "SQL: \n",
      "\n",
      "     SELECT\n",
      "        SUM(CASE WHEN email LIKE '%@gmail.com%' THEN 1 ELSE 0 END) AS gmail_users,\n",
      "        SUM(CASE WHEN email NOT LIKE '%@gmail.com%' THEN 1 ELSE 0 END) AS non_gmail_users,\n",
      "        (SUM(CASE WHEN email LIKE '%@gmail.com%' THEN 1 ELSE 0 END)::FLOAT / NULLIF(SUM(CASE WHEN email NOT LIKE '%@gmail.com%' THEN 1 ELSE 0 END), 0)) AS gmail_ratio\n",
      "    FROM users;\n"
     ]
    }
   ],
   "source": [
    "questions = ['Show me the day with the most users joining', 'Show me the project that has a task with the most comments', 'What is the ratio of users with gmail addresses vs without?']\n",
    "\n",
    "for question in questions:\n",
    "    prompt = f\"\"\"\n",
    "    ### Task \n",
    "\n",
    "    Generate a SQL query to answer the following question: `{question}` \n",
    "    \n",
    "    ### PostgreSQL Database Schema \n",
    "    The query will run on a database with the following schema: \n",
    "    ```\n",
    "    CREATE TABLE users (\n",
    "        user_id SERIAL PRIMARY KEY,\n",
    "        username VARCHAR(50) NOT NULL,\n",
    "        email VARCHAR(100) NOT NULL,\n",
    "        password_hash TEXT NOT NULL,\n",
    "        created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP\n",
    "    );\n",
    "    \n",
    "    CREATE TABLE projects (\n",
    "        project_id SERIAL PRIMARY KEY,\n",
    "        project_name VARCHAR(100) NOT NULL,\n",
    "        description TEXT,\n",
    "        start_date DATE,\n",
    "        end_date DATE,\n",
    "        owner_id INTEGER REFERENCES users(user_id)\n",
    "    );\n",
    "    \n",
    "    CREATE TABLE tasks (\n",
    "        task_id SERIAL PRIMARY KEY,\n",
    "        task_name VARCHAR(100) NOT NULL,\n",
    "        description TEXT,\n",
    "        due_date DATE,\n",
    "        status VARCHAR(50),\n",
    "        project_id INTEGER REFERENCES projects(project_id)\n",
    "    );\n",
    "    \n",
    "    CREATE TABLE taskassignments (\n",
    "        assignment_id SERIAL PRIMARY KEY,\n",
    "        task_id INTEGER REFERENCES tasks(task_id),\n",
    "        user_id INTEGER REFERENCES users(user_id),\n",
    "        assigned_date DATE NOT NULL DEFAULT CURRENT_TIMESTAMP\n",
    "    );\n",
    "    \n",
    "    CREATE TABLE comments (\n",
    "        comment_id SERIAL PRIMARY KEY,\n",
    "        content TEXT NOT NULL,\n",
    "        created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,\n",
    "        task_id INTEGER REFERENCES tasks(task_id),\n",
    "        user_id INTEGER REFERENCES users(user_id)\n",
    "    );\n",
    "    ```\n",
    "    \n",
    "    ### Answer \n",
    "    Here is the SQL query that answers the question: `{question}` \n",
    "    ```sql\n",
    "    \"\"\"\n",
    "\n",
    "    print (\"Question: \" + question)\n",
    "    print (\"SQL: \")\n",
    "    \n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "    generated_ids = model.generate(\n",
    "        **inputs,\n",
    "        num_return_sequences=1,\n",
    "        eos_token_id=100001,\n",
    "        pad_token_id=100001,\n",
    "        max_new_tokens=400,\n",
    "        do_sample=False,\n",
    "        num_beams=1,\n",
    "    \n",
    "    )\n",
    "    \n",
    "    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "    print(outputs[0].split(\"```sql\")[-1])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
